import os
import cv2
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class SFEWDataset(Dataset):
    """
    SFEW Dataset Adapter for Torch.

    The index is a headerless, comma-separated CSV. Column 0 holds an emotion
    name (a key of ``labelMap``). Column 1 holds an image path relative to
    ``base_directory``. Items are ``(image, label)`` pairs, where ``image`` is
    an RGB ``PIL.Image`` (or whatever ``transform`` returns) and ``label`` is
    the integer from ``labelMap``.
    """

    def __init__(self, phase, base_directory, transform=None):
        """
        :param phase: 'train' loads ``train.csv``; any other value (e.g. 'test',
            'val') loads ``val.csv``
        :param base_directory: Directory containing the CSV files; image paths
            listed in the CSV are resolved relative to it
        :param transform: Optional callable applied to each PIL image
        :raises KeyError: if the CSV contains an emotion name not in ``labelMap``
        """
        csv_name = 'train.csv' if phase == 'train' else 'val.csv'
        label_path = os.path.join(base_directory, csv_name)
        dataset = pd.read_csv(label_path, sep=',', header=None)

        # Emotion name -> integer class index. The ordering is not alphabetical;
        # presumably it must match the class indices a downstream model was
        # trained with -- confirm before changing.
        self.labelMap = {'Angry': 5,
                         'Disgust': 2,
                         'Fear': 1,
                         'Happy': 3,
                         'Neutral': 6,
                         'Sad': 4,
                         'Surprise': 0}
        try:
            self.labels = [self.labelMap[name] for name in dataset[0].to_list()]
        except KeyError as err:
            # Keep KeyError for callers, but say which file and which labels are valid.
            raise KeyError(
                f"Unknown emotion label {err.args[0]!r} in {label_path}; "
                f"expected one of {sorted(self.labelMap)}"
            ) from err

        self.img_paths = [os.path.join(base_directory, rel_path)
                          for rel_path in dataset[1].to_list()]

        self.transform = transform

    def __len__(self):
        """Return the number of (image, label) rows read from the CSV."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """
        Load the image at position ``idx`` and return ``(image, label)``.

        :raises OSError: if the image file is missing or cannot be decoded
        """
        label = self.labels[idx]
        path = self.img_paths[idx]

        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of deep inside cvtColor.
        image = cv2.imread(path)
        if image is None:
            raise OSError(f"Could not read image (missing or undecodable): {path}")

        # OpenCV decodes to BGR channel order; convert to RGB before handing to PIL.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        if self.transform is not None:
            image = self.transform(image)

        return image, label
